{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the third-party libraries this notebook imports.\n",
    "# %pip (unlike !pip) always installs into the environment of the running kernel.\n",
    "%pip install torch numpy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0., 0., 0.],\n",
      "        [0., 0., 0.]])\n",
      "tensor([[1., 1., 1.],\n",
      "        [1., 1., 1.]])\n",
      "tensor([[0.8090, 0.7935, 0.2099, 0.9279],\n",
      "        [0.8136, 0.7422, 0.4769, 0.4955],\n",
      "        [0.3602, 0.1178, 0.7852, 0.0228]])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# 创建张量（tensor）\n",
    "## 使用封装的函数创建张量\n",
    "zeros = torch.zeros(2, 3)\n",
    "print(zeros)\n",
    "\n",
    "ones = torch.ones(2, 3)\n",
    "print(ones)\n",
    "\n",
    "torch.manual_seed(1024)\n",
    "random = torch.rand(3, 4)\n",
    "print(random)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[2, 3, 4],\n",
      "        [1, 0, 1]])\n",
      "tensor([[2, 3, 4],\n",
      "        [1, 0, 1]])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor(True)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 创建张量（tensor）\n",
    "## 从Python对象创建\n",
    "data = [[2, 3, 4], [1, 0, 1]]\n",
    "t_data = torch.tensor(data)\n",
    "print(t_data)\n",
    "\n",
    "## 从Numpy对象创建\n",
    "import numpy as np\n",
    "\n",
    "n_data = np.array(data)\n",
    "tn_data = torch.from_numpy(n_data)\n",
    "print(tn_data)\n",
    "\n",
    "## Numpy bridge，也就是对numpy对象的改变会传导到张量\n",
    "n_data += 1\n",
    "torch.all(torch.from_numpy(n_data) == tn_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([3, 4])\n",
      "torch.Size([1, 3, 4])\n",
      "torch.Size([3, 4])\n",
      "tensor(True)\n",
      "False\n"
     ]
    }
   ],
   "source": [
    "# 变换张量维度\n",
    "## 增加或减少数据的维度\n",
    "a = torch.rand(3, 4)\n",
    "print(a.shape)\n",
    "## 增加维度\n",
    "b = a.unsqueeze(0)\n",
    "print(b.shape)\n",
    "## 减少维度\n",
    "c = b.squeeze(0)\n",
    "print(c.shape)\n",
    "## 数据相同，但是维度不同\n",
    "print(torch.all(c.eq(b)))\n",
    "print(c.shape == b.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) torch.Size([10])\n",
      "tensor([[0, 1, 2, 3, 4],\n",
      "        [5, 6, 7, 8, 9]])\n",
      "tensor([[0, 5],\n",
      "        [1, 6],\n",
      "        [2, 7],\n",
      "        [3, 8],\n",
      "        [4, 9]])\n"
     ]
    }
   ],
   "source": [
    "# 变换张量形状\n",
    "data = torch.tensor(range(0, 10))\n",
    "print(data, data.shape)\n",
    "view1 = data.view(2, 5)\n",
    "print(view1)\n",
    "transpose1 = view1.T\n",
    "print(transpose1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## A non-contiguous tensor cannot be reinterpreted with view()\n",
    "print(view1.is_contiguous(), transpose1.is_contiguous())\n",
    "try:\n",
    "    view2 = transpose1.view(1, 10)\n",
    "except RuntimeError as err:\n",
    "    # The failure is the point of this demo: report it instead of letting\n",
    "    # it abort a Restart & Run All of the notebook\n",
    "    print(f'RuntimeError: {err}')\n",
    "    # Copying into contiguous memory first makes view() legal again\n",
    "    view2 = transpose1.contiguous().view(1, 10)\n",
    "print(view2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[2., 2.],\n",
      "        [2., 2.]])\n",
      "tensor([[ 2.,  4.],\n",
      "        [ 8., 16.]])\n"
     ]
    }
   ],
   "source": [
    "# 逐元素操作（element-wise operations）\n",
    "twos = torch.ones(2, 2) * 2\n",
    "print(twos)\n",
    "\n",
    "powers = twos ** torch.tensor([[1, 2], [3, 4]])\n",
    "print(powers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1, 2, 3],\n",
      "        [4, 5, 6]])\n",
      "tensor([1, 2, 3])\n",
      "tensor([[ 1,  4,  9],\n",
      "        [ 4, 10, 18]])\n",
      "torch.Size([4, 5, 3, 2])\n"
     ]
    }
   ],
   "source": [
    "## 广播机制（tensor broadcasting）\n",
    "a = torch.tensor(range(1, 7)).view(2, 3)\n",
    "b = torch.tensor(range(1, 4)).view(   3)\n",
    "print(a)\n",
    "print(b)\n",
    "print(a * b)\n",
    "\n",
    "## 关于广播，更复杂的例子\n",
    "a =     torch.ones(4, 1, 3, 2)\n",
    "b = a * torch.rand(   5, 1, 2)\n",
    "print(b.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([3, 5])\n",
      "torch.Size([5, 8, 3, 5])\n"
     ]
    }
   ],
   "source": [
    "# 矩阵运算\n",
    "mat1 = torch.randn(3, 4)\n",
    "mat2 = torch.randn(4, 5)\n",
    "re = mat1 @ mat2\n",
    "print(re.shape)\n",
    "## 矩阵运算的广播\n",
    "mat1 = torch.randn(5, 1, 3, 4)\n",
    "mat2 = torch.randn(   8, 4, 5)\n",
    "re = mat1 @ mat2\n",
    "print(re.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([])\n",
      "torch.Size([3])\n",
      "torch.Size([10, 3])\n"
     ]
    }
   ],
   "source": [
    "# 向量运算\n",
    "# 向量与向量\n",
    "vec1 = torch.randn(3)\n",
    "vec2 = torch.randn(3)\n",
    "print((vec1 @ vec2).shape)\n",
    "# 矩阵与向量\n",
    "mat = torch.randn(3, 4)\n",
    "vec = torch.randn(4)\n",
    "print((mat @ vec).shape)\n",
    "# 张量与向量\n",
    "mat = torch.randn(10, 3, 4)\n",
    "vec = torch.randn(4)\n",
    "print((mat @ vec).shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
